{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow as tf\n",
    "\n",
    "# Import tf_text to load the ops used by the tokenizer saved model\n",
    "import tensorflow_text  # pylint: disable=unused-import"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.1. Positional Encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(1, 2048, 512), dtype=float32, numpy=\n",
       "array([[[ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00, ...,\n",
       "          1.0000000e+00,  0.0000000e+00,  1.0000000e+00],\n",
       "        [ 8.4147096e-01,  5.4030228e-01,  8.2185620e-01, ...,\n",
       "          1.0000000e+00,  1.0366329e-04,  1.0000000e+00],\n",
       "        [ 9.0929741e-01, -4.1614684e-01,  9.3641472e-01, ...,\n",
       "          1.0000000e+00,  2.0732658e-04,  1.0000000e+00],\n",
       "        ...,\n",
       "        [ 1.7589758e-01, -9.8440850e-01, -1.8608274e-01, ...,\n",
       "          9.7595036e-01,  2.1040717e-01,  9.7761387e-01],\n",
       "        [-7.3331332e-01, -6.7989087e-01,  7.0149130e-01, ...,\n",
       "          9.7592694e-01,  2.1050851e-01,  9.7759205e-01],\n",
       "        [-9.6831930e-01,  2.4971525e-01,  9.8535496e-01, ...,\n",
       "          9.7590351e-01,  2.1060985e-01,  9.7757018e-01]]], dtype=float32)>"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_angles(pos, i, dim):\n",
    "    \"\"\"Return the sinusoid arguments pos / 10000^(2*(i//2)/dim).\n",
    "\n",
    "    `pos` and `i` are combined by NumPy broadcasting, so the result has the\n",
    "    broadcast shape of the two.\n",
    "    \"\"\"\n",
    "    exponent = (2 * (i // 2)) / np.float32(dim)\n",
    "    return pos * (1 / np.power(10000, exponent))\n",
    "\n",
    "def positional_encoding(position, dim):\n",
    "    \"\"\"Build the sinusoidal positional-encoding table.\n",
    "\n",
    "    Returns a float32 tensor of shape (1, position, dim) holding the sine of\n",
    "    the angle in even channels and the cosine in odd channels.\n",
    "    \"\"\"\n",
    "    row_ids = np.arange(position)[:, np.newaxis]\n",
    "    col_ids = np.arange(dim)[np.newaxis, :]\n",
    "    table = get_angles(row_ids, col_ids, dim)\n",
    "\n",
    "    # Even channels (2i) carry the sine, odd channels (2i+1) the cosine.\n",
    "    table[:, 0::2] = np.sin(table[:, 0::2])\n",
    "    table[:, 1::2] = np.cos(table[:, 1::2])\n",
    "\n",
    "    # Prepend an axis of size 1 in front of the (position, dim) table.\n",
    "    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)\n",
    "\n",
    "\n",
    "# Encoding table for 2048 positions with 512 channels.\n",
    "n = 2048\n",
    "d = 512\n",
    "pos_encoding = positional_encoding(n, d)\n",
    "pos_encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.2. Masking Routines for Handling Padding and Hiding Future Words in the Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[[[0. 0. 1. 1. 0.]]]\n",
      "\n",
      "\n",
      " [[[0. 0. 0. 1. 1.]]]\n",
      "\n",
      "\n",
      " [[[1. 0. 0. 0. 0.]]]], shape=(3, 1, 1, 5), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0. 1. 1.]\n",
      " [0. 0. 1.]\n",
      " [0. 0. 0.]], shape=(3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "def create_padding_mask(seq):\n",
    "    \"\"\"Flag padding tokens (id 0) with 1.0 and every other token with 0.0.\n",
    "\n",
    "    Two singleton axes are inserted so the result has shape\n",
    "    (batch_size, 1, 1, seq_len) and can be added to the attention logits.\n",
    "    \"\"\"\n",
    "    is_pad = tf.math.equal(seq, 0)\n",
    "    pad_flags = tf.cast(is_pad, tf.float32)\n",
    "    return pad_flags[:, tf.newaxis, tf.newaxis, :]\n",
    "\n",
    "def create_look_ahead_mask(size):\n",
    "    \"\"\"Return a (size, size) mask that is 1.0 strictly above the diagonal.\"\"\"\n",
    "    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)\n",
    "    return 1 - lower_triangle\n",
    "\n",
    "# Three toy sequences in which token id 0 marks padding.\n",
    "x = tf.constant([\n",
    "    [1, 2, 0, 0, 1],\n",
    "    [1, 6, 7, 0, 0],\n",
    "    [0, 1, 1, 4, 5],\n",
    "])\n",
    "print(create_padding_mask(x))\n",
    "print(create_look_ahead_mask(3))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.3. Scaled Dot Product for Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_scaled_dot(Q, K, V, mask):\n",
    "    \"\"\"Scaled dot-product attention.\n",
    "\n",
    "    Computes softmax(Q K^T / sqrt(d_k)) V, where d_k is the size of the last\n",
    "    axis of K. When `mask` is given, positions where it equals 1 get a large\n",
    "    negative logit so their softmax weight is effectively zero.\n",
    "\n",
    "    Returns the tuple (output, attention_weights).\n",
    "    \"\"\"\n",
    "    key_dim = tf.cast(tf.shape(K)[-1], tf.float32)\n",
    "    logits = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(key_dim)\n",
    "\n",
    "    if mask is not None:\n",
    "        logits = logits + mask * -1e9\n",
    "\n",
    "    weights = tf.nn.softmax(logits, axis=-1)\n",
    "    return tf.matmul(weights, V), weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.4. Multihead Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([1, 30, 512]), TensorShape([1, 8, 30, 30]))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensorflow.keras.layers import Layer\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class multi_head_attention(Layer):\n",
    "    \"\"\"Multi-head attention built on `attention_scaled_dot`.\n",
    "\n",
    "    Q, K and V are each projected by a Dense layer to `dim` features, split\n",
    "    into `num_heads` heads of `dim // num_heads` features, attended per head,\n",
    "    merged again and passed through a final Dense projection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *, dim, num_heads):\n",
    "        super().__init__()\n",
    "        # Every head must receive an equal share of the model width.\n",
    "        assert dim % num_heads == 0\n",
    "\n",
    "        self.dim = dim\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = dim // num_heads\n",
    "\n",
    "        self.Wq = layers.Dense(dim)\n",
    "        self.Wk = layers.Dense(dim)\n",
    "        self.Wv = layers.Dense(dim)\n",
    "        self.dense = layers.Dense(dim)\n",
    "\n",
    "    def split_heads(self, x, batch_size):\n",
    "        \"\"\"Reshape (batch_size, seq_len, dim) to (batch_size, num_heads, seq_len, head_dim).\"\"\"\n",
    "        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))\n",
    "        return tf.transpose(per_head, perm=[0, 2, 1, 3])\n",
    "\n",
    "    def call(self, V, K, Q, mask):\n",
    "        \"\"\"Return (output, attention_weights) for the given V, K, Q and mask.\"\"\"\n",
    "        batch_size = tf.shape(Q)[0]\n",
    "\n",
    "        # Project, then split: (batch_size, seq_len, dim) -> (batch_size, num_heads, seq_len, head_dim)\n",
    "        q_heads = self.split_heads(self.Wq(Q), batch_size)\n",
    "        k_heads = self.split_heads(self.Wk(K), batch_size)\n",
    "        v_heads = self.split_heads(self.Wv(V), batch_size)\n",
    "\n",
    "        per_head_out, attention_weights = attention_scaled_dot(q_heads, k_heads, v_heads, mask)\n",
    "\n",
    "        # Merge heads: (batch_size, num_heads, seq_len_q, head_dim) -> (batch_size, seq_len_q, dim)\n",
    "        merged = tf.transpose(per_head_out, perm=[0, 2, 1, 3])\n",
    "        merged = tf.reshape(merged, (batch_size, -1, self.dim))\n",
    "\n",
    "        return self.dense(merged), attention_weights\n",
    "\n",
    "# Multi-headed self-attention: the same tensor serves as V, K and Q.\n",
    "mha_layer = multi_head_attention(dim=512, num_heads=8)\n",
    "x = tf.random.uniform((1, 30, 512))  # (batch_size, seq_len, dim)\n",
    "out, attn = mha_layer(V=x, K=x, Q=x, mask=None)\n",
    "(out.shape, attn.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.5. Pointwise Feed-Forward Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pointwise_mlp(dim, hidden_dim, activation='relu'):\n",
    "    \"\"\"Position-wise feed-forward network: Dense(hidden_dim, activation) -> Dense(dim).\"\"\"\n",
    "    expand = layers.Dense(hidden_dim, activation=activation)  # (batch_size, seq_len, hidden_dim)\n",
    "    project = layers.Dense(dim)  # (batch_size, seq_len, dim)\n",
    "    return tf.keras.Sequential([expand, project])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.6. The Encoder Layer Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder Shape:(32, 30, 512)\n"
     ]
    }
   ],
   "source": [
    "class encoder_layer(Layer):\n",
    "    \"\"\"One encoder block: self-attention followed by a feed-forward MLP.\n",
    "\n",
    "    Each sub-layer's output goes through dropout, is added to that sub-layer's\n",
    "    input, and the sum is layer-normalised.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *, dim, num_heads, hidden_dim, dropout=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.mha = multi_head_attention(dim=dim, num_heads=num_heads)\n",
    "        self.mlp = pointwise_mlp(dim, hidden_dim=hidden_dim)\n",
    "\n",
    "        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-6)\n",
    "\n",
    "        self.dropout_1 = layers.Dropout(dropout)\n",
    "        self.dropout_2 = layers.Dropout(dropout)\n",
    "\n",
    "    def call(self, x, training, mask):\n",
    "        \"\"\"Map x of shape (batch_size, input_seq_len, dim) to a tensor of the same shape.\"\"\"\n",
    "        # Self-attention sub-layer: x is passed as V, K and Q.\n",
    "        attended, _ = self.mha(x, x, x, mask)\n",
    "        attended = self.dropout_1(attended, training=training)\n",
    "        normed_attn = self.layernorm_1(x + attended)\n",
    "\n",
    "        # Feed-forward sub-layer.\n",
    "        ff = self.dropout_2(self.mlp(normed_attn), training=training)\n",
    "        return self.layernorm_2(normed_attn + ff)\n",
    "    \n",
    "# Smoke test: random (batch_size, input_seq_len, dim) input, training=False, no mask.\n",
    "enc_layer = encoder_layer(dim=512, num_heads=8, hidden_dim=2048)\n",
    "out_e = enc_layer(\n",
    "    tf.random.uniform((32, 30, 512)),\n",
    "    False,\n",
    "    None,\n",
    ")\n",
    "print(f\"Encoder Shape:{out_e.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.7. The Decoder Layer Illustration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decoder output shape (32, 30, 512)\n"
     ]
    }
   ],
   "source": [
    "class decoder_layer(Layer):\n",
    "    \"\"\"One decoder block: self-attention, cross-attention, then a feed-forward MLP.\n",
    "\n",
    "    Each sub-layer's output goes through dropout, is added to that sub-layer's\n",
    "    input, and the sum is layer-normalised.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *, dim, num_heads, hidden_dim, dropout=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.mha_1 = multi_head_attention(dim=dim, num_heads=num_heads)  # self-attention\n",
    "        self.mha_2 = multi_head_attention(dim=dim, num_heads=num_heads)  # cross-attention\n",
    "\n",
    "        self.mlp = pointwise_mlp(dim, hidden_dim=hidden_dim)\n",
    "\n",
    "        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm_3 = layers.LayerNormalization(epsilon=1e-6)\n",
    "\n",
    "        self.dropout_1 = layers.Dropout(dropout)\n",
    "        self.dropout_2 = layers.Dropout(dropout)\n",
    "        self.dropout_3 = layers.Dropout(dropout)\n",
    "\n",
    "    def call(self, x, encoder_out, training, look_ahead_mask, padding_mask):\n",
    "        \"\"\"Return (output, self_attention_weights, cross_attention_weights).\"\"\"\n",
    "        # 1) Self-attention over x, masked with look_ahead_mask.\n",
    "        self_attn, attn_wts_block_1 = self.mha_1(x, x, x, mask=look_ahead_mask)\n",
    "        self_attn = self.dropout_1(self_attn, training=training)\n",
    "        out_1 = self.layernorm_1(self_attn + x)\n",
    "\n",
    "        # 2) Cross-attention: encoder_out supplies V and K, out_1 supplies Q.\n",
    "        cross_attn, attn_wts_block_2 = self.mha_2(encoder_out, encoder_out, out_1, mask=padding_mask)\n",
    "        cross_attn = self.dropout_2(cross_attn, training=training)\n",
    "        out_2 = self.layernorm_2(cross_attn + out_1)\n",
    "\n",
    "        # 3) Position-wise feed-forward network.\n",
    "        ff = self.dropout_3(self.mlp(out_2), training=training)\n",
    "        out_3 = self.layernorm_3(ff + out_2)\n",
    "\n",
    "        return out_3, attn_wts_block_1, attn_wts_block_2\n",
    "    \n",
    "# Smoke test: random target-side input attends over the encoder-layer output `out_e`.\n",
    "dec_layer = decoder_layer(dim=512, num_heads=8, hidden_dim=2048)\n",
    "out_d, _, _ = dec_layer(\n",
    "    tf.random.uniform((32, 30, 512)),\n",
    "    out_e,\n",
    "    False,\n",
    "    None,\n",
    "    None,\n",
    ")\n",
    "print(\"Decoder output shape\", out_d.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.8. Define Encoder as Multiple Encoder Layer Stack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder output shape: (32, 30, 512)\n"
     ]
    }
   ],
   "source": [
    "class encoder(Layer):\n",
    "    \"\"\"Stack of `num_layers` encoder layers on top of token + positional embeddings.\"\"\"\n",
    "\n",
    "    def __init__(self, *, num_layers, dim, num_heads, hidden_dim, input_vocab_size, dropout=0.1, max_tokens=2048):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "        self.num_layers = num_layers\n",
    "        self.max_tokens = max_tokens\n",
    "        self.num_heads = num_heads\n",
    "        self.hidden_dim = hidden_dim\n",
    "\n",
    "        self.embedding = layers.Embedding(input_vocab_size, dim)\n",
    "        self.pos_encoding = positional_encoding(max_tokens, dim)\n",
    "\n",
    "        self.encoder_layers = [\n",
    "            encoder_layer(dim=dim, num_heads=num_heads, hidden_dim=hidden_dim, dropout=dropout)\n",
    "            for _ in range(num_layers)\n",
    "        ]\n",
    "        self.dropout = layers.Dropout(dropout)\n",
    "\n",
    "    def call(self, x, training, mask):\n",
    "        \"\"\"Embed the token ids in x and run them through every encoder layer.\"\"\"\n",
    "        input_seq_len = tf.shape(x)[1]\n",
    "\n",
    "        # Scale the embeddings by sqrt(dim), then add the first input_seq_len positional rows.\n",
    "        x = self.embedding(x) * tf.math.sqrt(tf.cast(self.dim, tf.float32))\n",
    "        x = x + self.pos_encoding[:, :input_seq_len, :]\n",
    "        x = self.dropout(x, training=training)\n",
    "\n",
    "        for layer in self.encoder_layers:\n",
    "            x = layer(x, training, mask)\n",
    "\n",
    "        return x  # (batch_size, input_seq_len, dim)\n",
    "    \n",
    "# Smoke test: a 6-layer encoder applied to random token ids of shape (batch_size, input_seq_len).\n",
    "enc = encoder(\n",
    "    num_layers=6, dim=512, num_heads=8,\n",
    "    hidden_dim=2048, input_vocab_size=1000,\n",
    ")\n",
    "X = tf.random.uniform((32, 30), dtype=tf.int64, minval=0, maxval=200)\n",
    "encoder_out = enc(X, training=False, mask=None)\n",
    "\n",
    "print(f\"Encoder output shape: {encoder_out.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.9. Decoder as a Stack of Decoder Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decoder output shape: (32, 45, 512)\n",
      "Cross Attention shape :(32, 8, 45, 30)\n"
     ]
    }
   ],
   "source": [
    "class decoder(Layer):\n",
    "    \"\"\"Stack of `num_layers` decoder layers on top of token + positional embeddings.\n",
    "\n",
    "    `call` returns the final hidden states together with a dict holding, for\n",
    "    each layer i, the self-attention weights under 'decoder_layer{i}_block1'\n",
    "    and the cross-attention weights under 'decoder_layer{i}_block2'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *, num_layers, dim, num_heads, hidden_dim, target_vocab_size, dropout=0.1, max_tokens=2048):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "        self.num_layers = num_layers\n",
    "        self.max_tokens = max_tokens\n",
    "\n",
    "        self.embedding = layers.Embedding(target_vocab_size, dim)\n",
    "        self.pos_encoding = positional_encoding(max_tokens, dim)\n",
    "\n",
    "        self.decoder_layers = [\n",
    "            decoder_layer(dim=dim, num_heads=num_heads, hidden_dim=hidden_dim, dropout=dropout)\n",
    "            for _ in range(num_layers)\n",
    "        ]\n",
    "        self.dropout = layers.Dropout(dropout)\n",
    "\n",
    "    def call(self, x, encoder_out, training, look_ahead_mask, padding_mask):\n",
    "        output_seq_len = tf.shape(x)[1]\n",
    "\n",
    "        # Scale the embeddings by sqrt(dim), then add the first output_seq_len positional rows.\n",
    "        x = self.embedding(x) * tf.math.sqrt(tf.cast(self.dim, tf.float32))\n",
    "        x = x + self.pos_encoding[:, :output_seq_len, :]\n",
    "        x = self.dropout(x, training=training)\n",
    "\n",
    "        attention_wts_dict = {}\n",
    "        for i, layer in enumerate(self.decoder_layers):\n",
    "            x, block_1, block_2 = layer(x, encoder_out, training, look_ahead_mask, padding_mask)\n",
    "            attention_wts_dict[f'decoder_layer{i}_block1'] = block_1\n",
    "            attention_wts_dict[f'decoder_layer{i}_block2'] = block_2\n",
    "\n",
    "        # x.shape == (batch_size, target_seq_len, dim)\n",
    "        return x, attention_wts_dict\n",
    "    \n",
    "# Smoke test: a 6-layer decoder attending over `encoder_out` from the encoder demo.\n",
    "dec = decoder(\n",
    "    num_layers=6, dim=512, num_heads=8,\n",
    "    hidden_dim=2048, target_vocab_size=1200,\n",
    ")\n",
    "X_target = tf.random.uniform((32, 45), dtype=tf.int64, minval=0, maxval=200)\n",
    "\n",
    "out_decoder, attn_wts_dict = dec(\n",
    "    X_target,\n",
    "    encoder_out=encoder_out,\n",
    "    training=False,\n",
    "    look_ahead_mask=None,\n",
    "    padding_mask=None,\n",
    ")\n",
    "\n",
    "print(f\"Decoder output shape: {out_decoder.shape}\")\n",
    "print(f\"Cross Attention shape :{attn_wts_dict['decoder_layer0_block2'].shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 4-7.10. Putting It All Together to Create the Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformer Output :(32, 45, 1200)\n",
      "Model: \"transformer_16\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " encoder_20 (encoder)        multiple                  9969152   \n",
      "                                                                 \n",
      " decoder_19 (decoder)        multiple                  13226496  \n",
      "                                                                 \n",
      " dense_1256 (Dense)          multiple                  615600    \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 23,811,248\n",
      "Trainable params: 23,811,248\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras import Model\n",
    "\n",
    "class transformer(Model):\n",
    "    \"\"\"Encoder-decoder Transformer with a final Dense projection to vocabulary logits.\n",
    "\n",
    "    Args:\n",
    "        num_layers: number of layers in the encoder and in the decoder.\n",
    "        dim: model width.\n",
    "        num_heads: attention heads per multi-head attention layer.\n",
    "        hidden_dim: hidden width of the feed-forward networks.\n",
    "        input_vocab_size: size of the encoder embedding table.\n",
    "        target_vocab_size: size of the decoder embedding table and of the\n",
    "            output logits.\n",
    "        dropout: dropout rate passed to the encoder and decoder.\n",
    "        max_tokens_input: positions in the encoder positional-encoding table.\n",
    "        max_tokens_output: positions in the decoder positional-encoding table.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,*, num_layers, dim, num_heads, hidden_dim, input_vocab_size,\n",
    "               target_vocab_size, dropout=0.1,max_tokens_input=20,max_tokens_output=20):\n",
    "        super(transformer,self).__init__()\n",
    "        self.encoder = encoder(num_layers=num_layers, dim=dim,\n",
    "                               num_heads=num_heads, hidden_dim=hidden_dim,\n",
    "                               input_vocab_size=input_vocab_size, dropout=dropout,max_tokens=max_tokens_input)\n",
    "\n",
    "        self.decoder = decoder(num_layers=num_layers, dim=dim,\n",
    "                               num_heads=num_heads, hidden_dim=hidden_dim,\n",
    "                               target_vocab_size=target_vocab_size, dropout=dropout,max_tokens=max_tokens_output)\n",
    "\n",
    "        self.final_layer = layers.Dense(target_vocab_size)\n",
    "\n",
    "    def call(self, inputs, training):\n",
    "        \"\"\"Run the full model.\n",
    "\n",
    "        Args:\n",
    "            inputs: pair (input_, target_) of token-id tensors.\n",
    "            training: forwarded to the encoder and decoder.\n",
    "\n",
    "        Returns:\n",
    "            (final_output, attn_wts_dict), where final_output has shape\n",
    "            (batch_size, target_seq_len, target_vocab_size).\n",
    "        \"\"\"\n",
    "        input_, target_ = inputs\n",
    "\n",
    "        padding_mask, look_ahead_mask = self.create_masks(input_, target_)\n",
    "\n",
    "        encoder_output = self.encoder(input_, training, padding_mask)  # (batch_size, inp_seq_len, dim)\n",
    "\n",
    "        # Decode the unpacked `target_` argument rather than a notebook-level\n",
    "        # global, so the model works on whatever batch it is called with.\n",
    "        decoder_output, attn_wts_dict = self.decoder(\n",
    "            target_, encoder_output, training, look_ahead_mask, padding_mask)\n",
    "\n",
    "        final_output = self.final_layer(decoder_output)  # (batch_size, target_seq_len, target_vocab_size)\n",
    "\n",
    "        return final_output, attn_wts_dict\n",
    "\n",
    "    def create_masks(self, input_, target_):\n",
    "        \"\"\"Return (padding_mask, look_ahead_mask) for the given token-id tensors.\n",
    "\n",
    "        padding_mask flags padding in `input_`. look_ahead_mask is the\n",
    "        element-wise maximum of the target padding mask and the\n",
    "        upper-triangular look-ahead mask built from the target length.\n",
    "        \"\"\"\n",
    "        padding_mask = create_padding_mask(input_)\n",
    "\n",
    "        look_ahead_mask = create_look_ahead_mask(tf.shape(target_)[1])\n",
    "        decoder_target_padding_mask = create_padding_mask(target_)\n",
    "        look_ahead_mask = tf.maximum(decoder_target_padding_mask, look_ahead_mask)\n",
    "\n",
    "        return padding_mask, look_ahead_mask\n",
    "\n",
    "model = transformer(\n",
    "    num_layers=3, dim=512, num_heads=8, hidden_dim=2048,\n",
    "    input_vocab_size=1000, target_vocab_size=1200,\n",
    "    max_tokens_input=30, max_tokens_output=45)\n",
    "\n",
    "# Named `input_tokens` rather than `input` so the Python builtin is not shadowed.\n",
    "input_tokens = tf.random.uniform((32, 30), dtype=tf.int64, minval=0, maxval=200)\n",
    "target = tf.random.uniform((32, 45), dtype=tf.int64, minval=0, maxval=200)\n",
    "\n",
    "model_output, _ = model([input_tokens, target], training=False)\n",
    "\n",
    "print(f\"Transformer Output :{model_output.shape}\")  # (batch_size, target_seq_len, target_vocab_size)\n",
    "# summary() prints the table itself and returns None, so wrapping it in print()\n",
    "# would add a stray \"None\" line to the output.\n",
    "model.summary()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
